import csop_helper as csop
import numpy as np

def parse_solution(solution):
    """Invert a room -> students mapping into a student -> room mapping.

    Args:
        solution: Mapping whose keys are room labels (callers in this file
            also use the label "deck") and whose values are iterables of the
            students placed under that label. Labels with no students
            contribute nothing to the result.

    Returns:
        A dict mapping each student to the label it was listed under, with
        entries inserted in ascending order of student key. If a student is
        listed under several labels, the last label encountered wins.
    """
    student_to_room = {
        student: room
        for room, occupants in solution.items()
        for student in occupants
    }

    # Rebuild the dict from sorted items so iteration order is deterministic.
    return dict(sorted(student_to_room.items()))


def check_violations(soln_dict, constraints):
    """Count the constraints broken by a student -> room assignment.

    Only constraints whose 'pair' members are all assigned to something
    other than "deck" are evaluated; the rest are ignored.

    Constraint types, applied to the rooms of the first two pair members:
        0: violated when the two rooms differ.
        1: violated when the two rooms are equal.
        2: violated when the integer room numbers do not differ by exactly 1.
        3: violated when the integer room numbers differ by at most 1.
    Any other type never counts as a violation.

    Args:
        soln_dict: Mapping of student -> room label (or "deck").
        constraints: Iterable of dicts with 'pair' and 'type' entries.

    Returns:
        The number of violated constraints.
    """
    placed = {student for student, room in soln_dict.items() if room != "deck"}
    applicable = [rule for rule in constraints if set(rule['pair']).issubset(placed)]

    broken_count = 0
    for rule in applicable:
        kind = rule['type']
        if kind not in (0, 1, 2, 3):
            # Unknown constraint types are silently skipped.
            continue

        members = rule['pair']
        first_room = soln_dict[members[0]]
        second_room = soln_dict[members[1]]

        if kind == 0:
            broken = first_room != second_room
        elif kind == 1:
            broken = first_room == second_room
        else:
            # Types 2 and 3 reason about numeric distance between rooms,
            # so the labels are converted to integers only here.
            gap = abs(int(first_room) - int(second_room))
            broken = (gap != 1) if kind == 2 else (gap <= 1)

        if broken:
            broken_count += 1

    return broken_count
            

def score_solution(soln_dict, payoff, constraints, students):
    """Score an assignment before and after constraint penalties.

    Args:
        soln_dict: Mapping of student -> room label (or "deck").
        payoff: Lookup indexed first by student and then by the integer
            room number, giving the payoff of placing that student there.
        constraints: Constraints forwarded to check_violations.
        students: Students to score; each must be a key of soln_dict.

    Returns:
        A (nominal_score, real_score) tuple. The nominal score sums the
        payoffs of all students not on "deck"; the real score subtracts
        100 for every violated constraint.
    """
    nominal_score = sum(
        payoff[student][int(soln_dict[student])]
        for student in students
        if soln_dict[student] != "deck"
    )

    penalty = 100 * check_violations(soln_dict, constraints)

    return (nominal_score, nominal_score - penalty)



def adjacent_solutions(soln_dict, rooms, students):
    """List every assignment reachable by moving a single student.

    Args:
        soln_dict: Mapping of student -> current room label.
        rooms: Candidate room labels a student may be moved to.
        students: Students that may be moved; each must be a key of
            soln_dict.

    Returns:
        A list of new dicts, one per (student, room) combination where the
        room differs from the student's current one. Results are ordered by
        student, then by room. soln_dict itself is not modified.
    """
    neighbours = []
    for student in students:
        for room in rooms:
            if soln_dict[student] == room:
                # Staying put is not a move.
                continue
            moved = dict(soln_dict)
            moved[student] = room
            neighbours.append(moved)

    return neighbours


def get_soln_pctiles(task_index, parsed_intermediateSolutions, phase):
    """Rank each intermediate solution against the moves available before it.

    For every solution after the first, all single-student moves from the
    previous solution are scored, and the result is the fraction of those
    moves whose score is less than or equal to the score of the solution
    actually reached.

    Args:
        task_index: Index of the task within csop.task_specs for the phase.
        parsed_intermediateSolutions: Sequence of records, each with a
            'solution' entry accepted by parse_solution.
        phase: Phase identifier used to build the 'phase_<phase>' key.

    Returns:
        A two-element list [nominal_pctiles, real_pctiles]. Each inner list
        has one entry per intermediate solution; the first entry is None
        because it has no predecessor.
    """
    spec = csop.task_specs['phase_{}'.format(phase)][task_index]
    students = spec['students']
    payoff = spec['payoff']
    constraints = spec['constraints']
    # Room labels are compared as strings against the parsed solutions.
    rooms = [str(room) for room in spec['rooms']]

    solution_dicts = [
        parse_solution(record['solution'])
        for record in parsed_intermediateSolutions
    ]

    nominal_soln_pctiles = []
    real_soln_pctiles = []

    for position, current in enumerate(solution_dicts):
        if position == 0:
            nominal_soln_pctiles.append(None)
            real_soln_pctiles.append(None)
            continue

        reached_nominal, reached_real = score_solution(
            current, payoff, constraints, students
        )

        # Score every one-move neighbour of the *previous* solution.
        neighbour_scores = [
            score_solution(neighbour, payoff, constraints, students)
            for neighbour in adjacent_solutions(
                solution_dicts[position - 1], rooms, students
            )
        ]
        neighbour_nominal = np.array([pair[0] for pair in neighbour_scores])
        neighbour_real = np.array([pair[1] for pair in neighbour_scores])

        nominal_soln_pctiles.append(np.mean(neighbour_nominal <= reached_nominal))
        real_soln_pctiles.append(np.mean(neighbour_real <= reached_real))

    return [nominal_soln_pctiles, real_soln_pctiles]